#=
计算自动微分，优化等机器学习相关的内容
=#

using ForwardDiff
using Enzyme


"""
Record of a single evaluation: the `input` vector, the flattened `output`
vector, and any extra keyword metadata captured at construction time.
"""
struct AutoJacTape
    input::Vector{Float64}
    output::Vector{Float64}
    kargs::Dict{Any, Any}
end


"""
Convenience constructor for quick result lookup: flattens `output` with
`vec` and collects the keyword arguments into the `kargs` dictionary.
"""
function AutoJacTape(input, output; kargs...)
    return AutoJacTape(input, vec(output), Dict{Any, Any}(kargs))
end


"""
!! Side effect: defines `funcforward` in the caller's scope (the quoted
block is `esc`aped, so macro hygiene is disabled and the name leaks into
the call site — it will overwrite any existing `funcforward` there).

Expands to code that builds a closure over `bmat_IsingADX` (fixing
`ehkslice`, the configuration value `hscfg[bi, xi]` lifted via `Val`, `U`,
and a zero fourth argument) and computes the ForwardDiff Jacobian with
respect to the 2-vector `[ny[xi], nz[xi]]`, reshaped to
`(4, 2*length(ny), 2)` so the last axis indexes the two derivative
directions.
"""
macro autojac(ehkslice, hscfg, bi, xi, U, ny, nz)
    esc(quote
        # closure over everything except the differentiated pair n = (n[1], n[2])
        funcforward = (n) -> bmat_IsingADX($ehkslice, Val($hscfg[$bi, $xi]), $U, 0.0, n[1], n[2])
        # Jacobian w.r.t. [ny[xi], nz[xi]]; reshape separates the two columns
        # of the Jacobian into the third dimension
        reshape(ForwardDiff.jacobian(funcforward, [$ny[$xi], $nz[$xi]]), 4, 2*length($ny), 2)
    end)
end


"""
!! Side effect: defines `funcforward` in the caller's scope (the quoted
block is `esc`aped, so macro hygiene is disabled and the name leaks into
the call site).

Enzyme forward-mode variant of `@autojac`: seeds the two components of
`[ny[xi], nz[xi]]` with the unit directions `[1,0]` and `[0,1]` via
`BatchDuplicated`, so a single `autodiff` call yields both partial
derivatives, which are then concatenated along a third dimension.
"""
macro autojac2(ehkslice, hscfg, bi, xi, U, ny, nz)
    esc(quote
        funcforward = (ehks, n) -> bmat_IsingADX(ehks, Val($hscfg[$bi, $xi]), $U, 0.0, n[1], n[2])
        # [2] is taken as the batch of directional derivatives from the tuple
        # returned by autodiff; NOTE(review): the tuple layout is
        # Enzyme-version dependent — confirm against the pinned Enzyme version.
        # cat(...; dims=3) stacks the two directional derivatives.
        cat(autodiff(Forward, funcforward, BatchDuplicated, Const($ehkslice),
        BatchDuplicated([$ny[$xi], $nz[$xi]], ([1.0, 0.0], [0.0, 1.0])))[2]...; 
        dims=3)
        #funcforward = (n) -> bmat_IsingADX($ehkslice, Val($hscfg[$bi, $xi]), $U, 0.0, n[1], n[2])
        #autodiff(Forward, funcforward, BatchDuplicated,
        #BatchDuplicated([$ny[$xi], $nz[$xi]], ([1.0, 0.0], [0.0, 1.0])))
    end)
end


"""
Marker type for the Adam optimizer; used only for dispatch.
"""
struct AdamOpt end

# Singleton instance so callers can write `next(Adam, ...)`.
const Adam = AdamOpt()

"""
    next(::AdamOpt, g, m, v, t; α=0.001, β1=0.9, β2=0.999, ϵ=1e-8)

Perform one Adam optimization step.

# Arguments
- `g`: current gradient.
- `m`, `v`: first/second moment estimates carried over from the previous
  step (all zeros on the first call).
- `t`: number of steps taken so far (pass 0 on the first call).

# Keywords
- `α`: step size; `β1`, `β2`: moment decay rates; `ϵ`: divisor floor.

# Returns
`(gt, mt, vt, tt)`: the Adam step `gt` (how the caller applies it — add or
subtract — is up to the caller), the updated moments `mt`/`vt`, and the
incremented step counter `tt = t + 1`, which is also used for the bias
correction below.
"""
function next(::AdamOpt, g::Vector{Float64}, m::Vector{Float64}, v::Vector{Float64},
    t::Int; α=0.001, β1=0.9, β2=0.999, ϵ=1e-8)
    #
    tt = t + 1
    mt = β1*m .+ (1-β1)*g           # biased first moment estimate
    vt = β2*v .+ (1-β2)*(g .^ 2)    # biased second moment estimate
    hmt = mt / (1-β1^(tt))          # bias-corrected first moment
    hvt = vt / (1-β2^(tt))          # bias-corrected second moment
    gt = @. α*hmt/(sqrt(hvt)+ϵ)
    return gt, mt, vt, tt
end


"""
!! Side effect: the quoted block is `esc`aped, so the generated methods are
added to the function named by `expr` in the caller's scope.

Makes `expr` a "detached" node for ForwardDiff: the primal value is
computed as usual, but all partial derivatives are forced to zero (an
identity with zero gradient, analogous to a stop-gradient).
"""
macro fd_detach_(expr)
    esc(quote
        # Scalar Dual input: evaluate the primal function on value(x), then
        # re-wrap each output element as a Dual with all-zero partials.
        # NOTE(review): assumes the primal result is iterable — confirm.
        $(expr)(x::ForwardDiff.Dual{Z,T,N}) where {Z,T,N} =
        begin
            outv = $(expr)(ForwardDiff.value(x))
            # 0.0*partials keeps the Partials element type while zeroing it
            [ForwardDiff.Dual{Z}(v,
            ntuple(i->0.0*ForwardDiff.partials(x,i), N)) for v in outv]
        end
        #
        # Vector-of-Dual input (e.g. from ForwardDiff.jacobian): strip the
        # duals, call the primal function, and rebuild zero-partial Duals.
        $(expr)(xs::Vector{ForwardDiff.Dual{Z,T,N}}) where {Z,T,N} =
        begin
            input = map(x_ -> ForwardDiff.value(x_), xs)
            outv = $(expr)(input)
            ret = Vector{ForwardDiff.Dual{Z,T,N}}(undef, length(outv))
            for ri in Base.OneTo(length(outv))
                # The full chain rule would be
                #   der = (f'_1 * partials(xs[1], 1) + f'_2 * partials(xs[2], 1),
                #          f'_1 * partials(xs[1], 2) + f'_2 * partials(xs[2], 2))
                # where the i-th tuple slot corresponds to the i-th seed
                # direction; here every term is multiplied by 0.0 so the
                # gradient is detached (only the zero Partials shape remains).
                ret[ri] = ForwardDiff.Dual{Z}(outv[ri], ntuple(i->0.0*ForwardDiff.partials(xs[i], i), N))
            end
            ret
        end
    end)
end

#=mtest.jl
include("./src/qmcmary.jl")
using ..qmcmary
using ForwardDiff

f(x) = [sum(x.^2 .- x .+ 2), sum(x.^2 .- 2x .+ 2), sum(x.^2 .-2x .+ 2)]


println(f(1.0), " ", ForwardDiff.derivative(f, 2.0))

qmcmary.@fd_detach_(f)
#f(x::ForwardDiff.Dual{Z,T,N}) where {Z,T,N} =
#begin
#    outv = f(ForwardDiff.value(x))
#    println(typeof(x))
#    lenv = length(outv)
#    [ForwardDiff.Dual{Z}(v, ntuple(i->0.0*ForwardDiff.partials(x, i), N)) for v in outv]
#end

#f(xs::Vector{ForwardDiff.Dual{Z,T,N}}) where {Z,T,N} =
#begin
#    for x in xs
#        println(typeof(x), ForwardDiff.value(x))
#        #push!(ret, ForwardDiff.Dual{Z,T,N}(outv[1], ForwardDiff.Partials{N,T}((1.0, 0.0))))
#    end
#    input = map(x_ -> ForwardDiff.value(x_), xs)
#    println(input)
#    outv = f(input)
#    println(outv)
#    ret = Vector{ForwardDiff.Dual{Z,T,N}}(undef, length(outv))
#    for ri in Base.OneTo(length(outv))
#        #der = (f'_1 * partials(xs[1], 1)+f'_2 * partials(xs[2], 1),
#        #       f'_1 * partials(xs[1], 2)+f'_2 * partials(xs[2], 2))
#        #tuple里的第i个和第i次调用或者vec里的第i个有关
#        ret[ri] = ForwardDiff.Dual{Z}(outv[ri], ntuple(i->1.0*ForwardDiff.partials(xs[i], i), N))
#    end
#    ret
#end


fw = ForwardDiff.derivative(f, 2.0)
println(f(1.0), " ", fw)
fw = ForwardDiff.jacobian(f, [2.0, 3.0])
println(f(1.0), " ", fw)
=#
